"""
Evaluation module for cognitive maps.

This module provides the main evaluation logic for assessing:
1. Answer accuracy
2. Cognitive map validity
3. Cognitive map similarity metrics
"""

import json
from collections import defaultdict
from typing import Dict, List, Tuple, Set, Any, Optional

from .extraction import (
    extract_json_from_text as extract_cogmap_json, 
    extract_answer,
    get_setting_from_id, 
    extract_model_and_version,
    determine_answer_field
)
from .metrics import calculate_cogmap_similarity
from .io_utils import initialize_results_structure, print_results, save_json_results, save_csv_results

def determine_answer_fields(item: Dict) -> Tuple[str, str]:
    """
    Pick the keys of ``item`` that hold the model's answers.

    The cognitive-map answer key is the first of ``'cogmap_gen_answer'`` and
    ``'cogmap_answer'`` that is present in ``item``, falling back to
    ``'answer'``. The plain answer key is ``'plain_answer'`` when present,
    otherwise ``'answer'``. Only key presence is checked, not the values.

    Args:
        item: The evaluation item dictionary

    Returns:
        Tuple of (cogmap_field, plain_field)
    """
    preferred_cogmap_keys = ('cogmap_gen_answer', 'cogmap_answer')
    cogmap_key = next((key for key in preferred_cogmap_keys if key in item), 'answer')
    plain_key = 'plain_answer' if 'plain_answer' in item else 'answer'
    return cogmap_key, plain_key

# Similarity metrics returned by calculate_cogmap_similarity that are averaged
# over valid graphs; each is stored in the results under the key 'avg_<name>'.
_SIMILARITY_METRICS = (
    'relative_position_accuracy',
    'facing_similarity',
    'directional_similarity',
    'overall_similarity',
)


def _load_jsonl(jsonl_path: str) -> List[Dict]:
    """Read a JSONL file into a list of objects, skipping blank lines."""
    records = []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank lines, e.g. a trailing empty line
                records.append(json.loads(line))
    return records


def _as_cogmap(value: Any) -> Any:
    """Parse ``value`` with extract_cogmap_json if it is a string; otherwise return it unchanged."""
    if isinstance(value, str):
        return extract_cogmap_json(value)
    return value


def _extract_generated_cogmap(item: Dict, cogmap_field: str, cogmap_answer: Any) -> Any:
    """
    Locate the model-generated cognitive map for ``item``.

    Tries, in order: the answer text itself, a dedicated ``'cognitive_map'``
    field, then ``'<cogmap_field>_cognitive_map'``. Returns a falsy value when
    nothing is found. Propagates any exception raised by extract_cogmap_json.
    """
    generated = extract_cogmap_json(cogmap_answer)
    if generated:
        return generated
    if 'cognitive_map' in item:
        return _as_cogmap(item.get('cognitive_map'))
    dedicated_field = cogmap_field + '_cognitive_map'
    if dedicated_field in item:
        return _as_cogmap(item.get(dedicated_field))
    # The 'cogmap' field is deliberately NOT used as a fallback: it holds the
    # ground-truth map, and treating it as the generated map would score the
    # ground truth against itself (perfect similarity), inflating the metrics.
    return generated


def evaluate_answers_and_cogmaps(jsonl_path: str) -> Tuple[Dict, Dict]:
    """
    Evaluate answer accuracy and generated cognitive-map quality for one JSONL file.

    Each non-blank line of the file must be a JSON object. Items without a
    truthy ``'gt_answer'`` are skipped. For each remaining item, the answer
    extracted from the cognitive-map answer field is compared with
    ``'gt_answer'``. When a generated cognitive map can be found, it is also
    scored against the ground-truth map in the item's ``'cogmap'`` field via
    calculate_cogmap_similarity.

    Per-setting statistics cover every setting. The top-level ("overall")
    statistics only include settings whose ``'include_in_overall'`` flag is
    true (default True); per the comments in this module this is meant to
    exclude translation.

    Args:
        jsonl_path: Path to the JSONL file with data to evaluate

    Returns:
        Tuple of (results, error_cases).

        - ``results['total']`` is the number of evaluated items in included
          settings.
        - ``results['unfiltered_total']`` is the number of records read from
          the file.
        - ``error_cases['gen_cogmap_error']`` lists items where no answer
          could be extracted.
        - ``error_cases['cogmap_extraction_error']`` lists exceptions raised
          while parsing a generated or ground-truth map.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    data = _load_jsonl(jsonl_path)

    results = initialize_results_structure()
    results['total'] = len(data)

    error_cases = {
        'gen_cogmap_error': [],
        'cogmap_extraction_error': []
    }

    # Counters restricted to settings with include_in_overall=True.
    filtered_total = 0
    filtered_gen_cogmap_correct = 0
    filtered_valid_cogmap_count = 0
    filtered_isomorphic_count = 0
    filtered_rotation_invariant_isomorphic_count = 0
    filtered_parsable_json_count = 0
    filtered_valid_format_count = 0
    filtered_similarity_sums = dict.fromkeys(_SIMILARITY_METRICS, 0.0)

    for item in data:
        gt_answer = item.get('gt_answer')
        if not gt_answer:
            continue

        item_id = item.get('id', '')
        setting = get_setting_from_id(item_id)
        setting_stats = results['settings'][setting]
        sim_stats = setting_stats['cogmap_similarity']
        setting_stats['total'] += 1

        include_in_overall = setting_stats.get('include_in_overall', True)
        if include_in_overall:
            filtered_total += 1

        cogmap_field, _ = determine_answer_fields(item)
        # A missing or null answer becomes '' so extraction and truncation are safe.
        cogmap_answer = item.get(cogmap_field) or ''
        extracted_gen_cogmap = extract_answer(cogmap_answer)

        generated_cogmap = None
        try:
            generated_cogmap = _extract_generated_cogmap(item, cogmap_field, cogmap_answer)
        except Exception as e:  # extractor failures are recorded per item, not fatal
            error_cases['cogmap_extraction_error'].append({
                'id': item_id,
                'type': 'generated',
                'response': str(cogmap_answer)[:500],  # Truncate for readability
                'error': str(e)
            })

        grounded_cogmap = None
        try:
            grounded_cogmap = _as_cogmap(item.get('cogmap'))
        except Exception as e:  # a malformed ground-truth map should not abort the file
            error_cases['cogmap_extraction_error'].append({
                'id': item_id,
                'type': 'grounded',
                'response': str(item.get('cogmap'))[:500],  # Truncate for readability
                'error': str(e)
            })

        # Track cases where no answer could be extracted
        if not extracted_gen_cogmap:
            error_cases['gen_cogmap_error'].append({
                'id': item_id,
                'question': item.get('question', ''),
                'gt_answer': gt_answer,
                'answer': cogmap_answer
            })

        gen_cogmap_correct = bool(extracted_gen_cogmap) and extracted_gen_cogmap == gt_answer
        if gen_cogmap_correct:
            setting_stats['gen_cogmap_correct'] += 1
            if include_in_overall:
                filtered_gen_cogmap_correct += 1

        # Similarity metrics need both a generated and a ground-truth map.
        if not (generated_cogmap and grounded_cogmap):
            continue

        similarity = calculate_cogmap_similarity(generated_cogmap, grounded_cogmap)

        # valid_format is only counted for maps that were parsable JSON.
        if similarity.get("parsable_json", False):
            sim_stats['parsable_json_count'] += 1
            if include_in_overall:
                filtered_parsable_json_count += 1
            if similarity.get("valid_format", False):
                sim_stats['valid_format_count'] += 1
                if include_in_overall:
                    filtered_valid_format_count += 1

        # Graph-level metrics are only accumulated for valid graphs.
        if not similarity.get("valid_graph", False):
            continue

        sim_stats['total_valid'] += 1
        if include_in_overall:
            filtered_valid_cogmap_count += 1

        # Update isomorphic count (backward compatibility)
        if similarity.get("isomorphic", False):
            sim_stats['isomorphic_count'] += 1
            if include_in_overall:
                filtered_isomorphic_count += 1

        if similarity.get("rotation_invariant_isomorphic", False):
            sim_stats['rotation_invariant_isomorphic_count'] += 1
            if include_in_overall:
                filtered_rotation_invariant_isomorphic_count += 1

        # Per-setting 'avg_*' fields hold running sums here; they are divided
        # by the valid count after the loop.
        for metric in _SIMILARITY_METRICS:
            value = similarity.get(metric, 0.0)
            sim_stats['avg_' + metric] += value
            if include_in_overall:
                filtered_similarity_sums[metric] += value

        best_rotation = similarity.get("best_rotation")
        rotation_name = best_rotation.get('name', 'none') if best_rotation is not None else 'none'
        # The global rotation distribution only covers included settings ...
        if include_in_overall:
            results['cogmap_similarity']['rotation_distribution'][rotation_name] += 1
        # ... while every setting tracks its own distribution.
        sim_stats.setdefault('rotation_distribution', defaultdict(int))[rotation_name] += 1

    # Overall answer accuracy (included settings only).
    results['gen_cogmap_correct'] = filtered_gen_cogmap_correct
    results['gen_cogmap_accuracy'] = filtered_gen_cogmap_correct / filtered_total if filtered_total > 0 else 0

    # Overall cognitive-map metrics (included settings only).
    overall_sim = results['cogmap_similarity']
    overall_sim['total_valid'] = filtered_valid_cogmap_count
    overall_sim['valid_percent'] = 100 * filtered_valid_cogmap_count / filtered_total if filtered_total > 0 else 0
    overall_sim['isomorphic_count'] = filtered_isomorphic_count
    overall_sim['rotation_invariant_isomorphic_count'] = filtered_rotation_invariant_isomorphic_count
    overall_sim['parsable_json_count'] = filtered_parsable_json_count
    overall_sim['valid_format_count'] = filtered_valid_format_count

    # Keep the raw record count for reference; 'total' becomes the filtered
    # count so downstream printing/rates use the same denominator.
    results['unfiltered_total'] = results['total']
    results['total'] = filtered_total

    if filtered_valid_cogmap_count > 0:
        for metric in _SIMILARITY_METRICS:
            overall_sim['avg_' + metric] = filtered_similarity_sums[metric] / filtered_valid_cogmap_count

    # Finalize per-setting rates and turn the accumulated sums into averages.
    for stats in results['settings'].values():
        stats['gen_cogmap_accuracy'] = stats['gen_cogmap_correct'] / stats['total'] if stats['total'] > 0 else 0

        setting_sim = stats['cogmap_similarity']
        valid_count = setting_sim['total_valid']
        setting_sim['valid_percent'] = 100 * valid_count / stats['total'] if stats['total'] > 0 else 0

        if valid_count > 0:
            for metric in _SIMILARITY_METRICS:
                setting_sim['avg_' + metric] /= valid_count

    return results, error_cases

def _summary_metrics(stats: Dict, total: int) -> Dict:
    """Build the shared accuracy / cogmap-similarity columns for one stats dict (overall or per-setting)."""
    sim = stats['cogmap_similarity']
    return {
        'gen_cogmap_accuracy': stats['gen_cogmap_accuracy'],
        'valid_cogmaps': sim['total_valid'],
        'valid_rate': sim['valid_percent'] / 100.0 if sim['valid_percent'] else 0.0,
        'isomorphic_count': sim['isomorphic_count'],
        'isomorphic_rate': sim['isomorphic_count'] / total if total > 0 else 0,
        'avg_relative_position_accuracy': sim['avg_relative_position_accuracy'],
        'avg_facing_similarity': sim['avg_facing_similarity'],
        'avg_directional_similarity': sim['avg_directional_similarity'],
        'avg_overall_similarity': sim.get('avg_overall_similarity', 0.0),
    }


def batch_evaluate(eval_dir: str, output_dir: str) -> Tuple[List[Dict], Dict[str, List[Dict]]]:
    """
    Evaluate every ``*.jsonl`` file in ``eval_dir`` and save the results.

    Files are processed in sorted filename order so that output row order is
    reproducible. Files whose model cannot be determined from the filename
    are skipped. For each evaluated file, ``<stem>_results.json`` (results
    plus error cases) is written to ``output_dir``. At the end, timestamped
    JSON and CSV summaries are written there as well.

    Args:
        eval_dir: Directory containing JSONL files for batch evaluation
        output_dir: Directory to save batch evaluation results (created if missing)

    Returns:
        Tuple of (all_results, all_settings_results). all_results has one
        summary row per file. all_settings_results maps 'around',
        'rotation', 'translation', 'among' and 'overall' to lists of rows.
    """
    import os
    import glob
    from datetime import datetime
    import pandas as pd

    os.makedirs(output_dir, exist_ok=True)

    # glob order depends on the filesystem; sort for deterministic output.
    jsonl_files = sorted(glob.glob(os.path.join(eval_dir, "*.jsonl")))
    if not jsonl_files:
        print(f"No JSONL files found in {eval_dir}")

    all_results = []
    all_settings_results = {
        'around': [],
        'rotation': [],
        'translation': [],
        'among': [],
        'overall': []
    }

    for jsonl_file in jsonl_files:
        filename = os.path.basename(jsonl_file)
        print(f"Evaluating {filename}...")

        model_name, version, gen_cogmap = extract_model_and_version(filename)

        # Skip files we can't categorize properly
        if model_name == "unknown":
            print(f"Skipping {filename} - could not determine model")
            continue

        results, error_cases = evaluate_answers_and_cogmaps(jsonl_file)
        print_results(results)

        identity = {
            'filename': filename,
            'model': model_name,
            'version': version,
            'gen_cogmap': gen_cogmap,
        }
        # results['total'] is the filtered count (included settings only);
        # 'unfiltered_total' is the raw number of records in the file.
        filtered_total = results['total']
        unfiltered_total = results.get('unfiltered_total', filtered_total)
        overall_metrics = _summary_metrics(results, filtered_total)

        all_results.append({
            **identity,
            'total_examples': filtered_total,
            'unfiltered_total': unfiltered_total,
            **overall_metrics,
        })
        all_settings_results['overall'].append({
            **identity,
            'total': filtered_total,
            'unfiltered_total': unfiltered_total,
            **overall_metrics,
        })

        # Per-setting rows, including settings excluded from the overall
        # metrics (marked as reported separately).
        for setting in ['around', 'rotation', 'translation', 'among']:
            setting_stats = results['settings'].get(setting)
            if not setting_stats or setting_stats['total'] <= 0:
                continue
            included = setting_stats.get('include_in_overall', True)
            all_settings_results[setting].append({
                **identity,
                'include_in_overall': included,
                'include_status': "(included in overall)" if included else "(reported separately)",
                'total': setting_stats['total'],
                **_summary_metrics(setting_stats, setting_stats['total']),
            })

        # Save individual result
        output_file = os.path.join(output_dir, f"{os.path.splitext(filename)[0]}_results.json")
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                **identity,
                'results': results,
                'error_cases': error_cases
            }, f, indent=2)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    all_results_file = os.path.join(output_dir, f"all_results_{timestamp}.json")
    with open(all_results_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)

    all_settings_file = os.path.join(output_dir, f"all_settings_results_{timestamp}.json")
    with open(all_settings_file, 'w', encoding='utf-8') as f:
        json.dump(all_settings_results, f, indent=2)

    # CSV exports for easier analysis
    df_results_file = os.path.join(output_dir, f"all_results_{timestamp}.csv")
    pd.DataFrame(all_results).to_csv(df_results_file, index=False)

    for setting, setting_results in all_settings_results.items():
        if setting_results:
            df_setting_file = os.path.join(output_dir, f"{setting}_results_{timestamp}.csv")
            pd.DataFrame(setting_results).to_csv(df_setting_file, index=False)

    print(f"\nBatch evaluation complete. Results saved to {output_dir}")
    print(f"Summary results: {all_results_file}")
    print(f"CSV results: {df_results_file}")

    return all_results, all_settings_results